Summary of this document

This note demonstrates how to:

- prepare some data in R
- use the R package reticulate to pass the data to Python
- fit a NumPyro Poisson mixed model
- plot the MCMC output using ArviZ

Setting up reticulate

# Load reticulate, the R <-> Python bridge used throughout this note.
library(reticulate)
# Bind to the conda environment named "anaconda3"; the second argument is the
# path to the conda executable.
# NOTE(review): this path is machine-specific -- adjust it for your own setup.
use_condaenv("anaconda3", "/home/jmellor/anaconda3/bin/conda")
# Switch matplotlib to the non-interactive "Agg" backend before any plotting
# code runs (presumably so figures can be rendered without a display).
matplotlib <- import("matplotlib")
matplotlib$use("Agg", force = TRUE)

Preparing toy data in R

## ---- Simulation settings ----
N_indivs <- 1000 # num individuals
intervals.mean <- 20 # Poisson mean of the number of extra intervals per individual
beta <- c(-0.5, 0) # log rate ratio parameters
gamma <- 0.7 # coefficient for effect of true value of biomarker

## NOTE(review): no set.seed() call is visible in this note, so each run
## simulates a different data set; set a seed first if reproducibility matters.

## Simulated model: one Poisson count y per person-time interval, with
##   log(rate) = individual baseline + gamma * true biomarker value
##               + beta[1] * x1 + beta[2] * x2
## The true biomarker value (z.true) is constant within an individual; z.obs is
## a noisy measurement of it that does not enter the rate.

## simulate individual baseline log hazard rates from a normal distribution
logbaseline <- rnorm(N_indivs, -3, 1)

## simulate individual means for biomarkers from a normal distribution
mean.z <- rnorm(N_indivs, 0, 1)

## simulate fixed covariates
x1.indiv <- rnorm(N_indivs, 0, 1)
x2.indiv <- rnorm(N_indivs, 0, 1)

## generate a random number (at least one) of person-time intervals for each
## individual; named n_intervals rather than T so the TRUE shorthand is not masked
n_intervals <- 1L + rpois(n = N_indivs, lambda = intervals.mean)
N <- sum(n_intervals)

## Build the long-format data (one row per individual per interval) in a single
## vectorised step instead of growing a data frame with rbind() inside a loop.
## Measured values are still drawn individual by individual, in row order.
mixdata <- data.frame(
  indiv = factor(rep(seq_len(N_indivs), times = n_intervals)),
  time = sequence(n_intervals), # 1, 2, ..., n_intervals[i] within each individual
  logbaseline.indiv = rep(logbaseline, times = n_intervals),
  x1 = rep(x1.indiv, times = n_intervals),
  x2 = rep(x2.indiv, times = n_intervals),
  z.true = rep(mean.z, times = n_intervals),
  z.obs = rnorm(N, mean = rep(mean.z, times = n_intervals)) # simulate measured values
)

## event rate for every person-time interval
rate <- with(mixdata, exp(logbaseline.indiv + gamma * z.true +
                            beta[1] * x1 + beta[2] * x2))

## simulated event counts, one per row of mixdata
mixdata$y <- rpois(n = N, lambda = rate)

## Prior hyperparameters handed to the Python model
scale_icept <- 10 # SD of the normal prior on the intercept
mean_icept <- 0 # mean of the normal prior on the intercept

## Design matrix without the intercept column. drop = FALSE keeps X a matrix
## even if the formula is reduced to a single covariate, so that ncol(X) and
## nrow(X) below keep working.
X <- model.matrix(object = y ~ x1 + x2, data = mixdata)[, -1, drop = FALSE]
print(head(X))
## Example output (values differ between runs unless a seed is set):
##         x1        x2
## 1 0.466426 0.5964039
## 2 0.466426 0.5964039
## 3 0.466426 0.5964039
## 4 0.466426 0.5964039
## 5 0.466426 0.5964039
## 6 0.466426 0.5964039

## centre (but do not rescale) each covariate column
X <- scale(X, center = TRUE, scale = FALSE)
U <- ncol(X) # number of covariates (intercept excluded)
indiv <- as.integer(mixdata$indiv) # 1-based individual code for every row
N <- nrow(X) # number of person-time rows

## Everything the Python side needs, bundled in one named list
testdata <- list(
  N = N, U = U, N_indivs = N_indivs, indiv = indiv,
  mean_icept = mean_icept,
  scale_icept = scale_icept,
  scale_other = 3, # SD of the normal prior on the covariate coefficients
  y = mixdata$y, X = X
)

Passing data from R to Python

Here we use the `r` object, which reticulate makes available inside Python, to access the data prepared in R.

import warnings
# Silence all Python warnings from this point on.
# NOTE(review): this is a blanket filter; it will also hide any warnings
# raised later by the libraries imported below.
warnings.filterwarnings("ignore")
import numpy as np

def get_data():
    """Pull ``testdata`` out of the R session and adapt it for the Python model.

    Reads ``r.testdata`` and replaces three of its entries:

    * ``indiv``    -- NumPy array of the original codes, each reduced by one
                      (R's 1-based codes become 0-based indices);
    * ``y``        -- converted to a NumPy array;
    * ``N_indivs`` -- cast to ``int``.

    Every other entry is returned exactly as received.
    """
    payload = r.testdata

    zero_based = np.array([code - 1 for code in payload['indiv']])
    payload['indiv'] = zero_based

    counts = np.array(payload['y'])
    payload['y'] = counts

    payload['N_indivs'] = int(payload['N_indivs'])
    return payload

# Convert the R-side data once; the result is passed to fit() further down.
data = get_data()

Fitting a numpyro model

import jax.numpy as jnp
import os
import time as tm
import jax.random as random

import numpyro
import numpyro.distributions as dist
import numpyro.optim as optim
from numpyro.infer import MCMC, NUTS

def model(X, y, mean_icept, scale_icept, scale_other, N_indivs, indiv):
    """NumPyro model: Poisson regression with one random intercept per individual.

    Sample sites, in order:
      * ``alpha``       -- Normal(mean_icept, scale_icept), the overall intercept
      * ``theta``       -- Normal(0, scale_other) coefficients, shape (X.shape[1], 1)
      * ``xi``          -- standard-normal draws, one per individual
      * ``sigma_indiv`` -- HalfCauchy(5) scale of the individual intercepts
      * ``obs``         -- Poisson counts observed as ``y``

    The intercept of individual i is ``alpha + xi[i] * sigma_indiv``; ``indiv``
    selects the intercept that belongs to each row of ``X``.
    """
    alpha = numpyro.sample('alpha', dist.Normal(mean_icept, scale_icept))

    coef_prior = dist.Normal(jnp.zeros((X.shape[1], 1)), scale_other)
    theta = numpyro.sample('theta', coef_prior)

    unit_prior = dist.Normal(jnp.zeros(N_indivs), jnp.ones(N_indivs))
    xi = numpyro.sample('xi', unit_prior)
    sigma_indiv = numpyro.sample('sigma_indiv', dist.HalfCauchy(5.))

    # Per-individual intercepts, then one log-rate per observation row.
    intercepts = alpha + xi * sigma_indiv
    log_rate = (X @ theta + intercepts[indiv].reshape(-1, 1)).flatten()

    with numpyro.plate('data', X.shape[0]):
        numpyro.sample('obs', dist.Poisson(jnp.exp(log_rate)), obs=y)

def fit(data, platform='gpu', num_chains=4, num_warmup=1000, num_samples=1000, seed=0):
    """Run NUTS on ``model`` and return the fitted ``MCMC`` object.

    Parameters
    ----------
    data : dict
        Must provide the keys ``X``, ``y``, ``mean_icept``, ``scale_icept``,
        ``scale_other``, ``N_indivs`` and ``indiv``; they are forwarded to
        ``model`` as keyword arguments.
    platform : str
        Passed to ``numpyro.set_platform`` (default ``'gpu'``, as before).
    num_chains : int
        Number of MCMC chains; also the host device count requested from XLA.
    num_warmup, num_samples : int
        Warm-up and post-warm-up draws per chain.
    seed : int
        Seed for the JAX PRNG key.

    Returns
    -------
    numpyro.infer.MCMC
        The sampler after ``run`` has completed.
    """
    # Configure the backend BEFORE touching JAX: NumPyro documents that these
    # utilities only take effect at the beginning of the program, and creating
    # a PRNG key already uses JAX. Previously they ran after the key was created.
    # NOTE(review): if JAX has already been used elsewhere in the session these
    # calls may still come too late -- ideally run them right after the imports.
    numpyro.set_platform(platform)
    numpyro.set_host_device_count(num_chains)

    # Same key derivation as before: split once and sample with the second key.
    rng_key = random.PRNGKey(seed)
    _, run_key = random.split(rng_key)

    nuts = NUTS(model)
    mcmc = MCMC(nuts, num_samples=num_samples, num_warmup=num_warmup,
                num_chains=num_chains)
    mcmc.run(run_key,
             X=data['X'],
             y=data['y'],
             mean_icept=data['mean_icept'],
             scale_icept=data['scale_icept'],
             scale_other=data['scale_other'],
             N_indivs=data['N_indivs'],
             indiv=data['indiv']
             )
    return mcmc

# Run the sampler; the returned MCMC object is handed to ArviZ below.
mcmc = fit(data)

Plotting the results using Arviz

import arviz as az
import matplotlib.pyplot as plt

# Convert the fitted NumPyro MCMC object into ArviZ's data structure.
inf_data = az.from_numpyro(mcmc)
# Draw trace plots for the sampled variables.
az.plot_trace(inf_data)
# NOTE(review): matplotlib was switched to the "Agg" backend earlier in this
# note, so this relies on the rendering tool to capture the figure -- confirm.
plt.show()